{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# torch.nn.init\n",
    "## torch.nn.init.calculate_gain\n",
    "torch.nn.init.calculate_gain(nonlinearity,param=None)  \n",
    "对于给定的非线性函数，返回推荐的增益值。  \n",
    "参数：  \n",
    "- nonlinearity - 非线性函数（ nn.functional 名称）\n",
    "- param - 非线性函数的可选参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.4141428569978354\n"
     ]
    }
   ],
   "source": [
    "gain = nn.init.calculate_gain('leaky_relu')\n",
    "print(gain)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## torch.nn.init.uniform\n",
    "torch.nn.init.uniform(tensor, a=0, b=1)  \n",
    "从均匀分布U(a, b)中生成值，填充输入的张量或变量  \n",
    "参数：  \n",
    "- tensor - n维的torch.Tensor\n",
    "- a - 均匀分布的下界\n",
    "- b - 均匀分布的上界"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.5376, 0.3341, 0.1739, 0.8834, 0.9355],\n",
      "        [0.6246, 0.8807, 0.0532, 0.9090, 0.8394],\n",
      "        [0.6430, 0.7363, 0.4549, 0.3403, 0.7400]])\n"
     ]
    }
   ],
   "source": [
    "w = torch.Tensor(3,5)\n",
    "nn.init.uniform_(w)\n",
    "print(w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## torch.nn.init.normal\n",
    "torch.nn.init.normal(tensor, mean=0, std=1)  \n",
    "从给定均值和标准差的正态分布N(mean, std)中生成值，填充输入的张量或变量。  \n",
    "参数：  \n",
    "- tensor – n维的torch.Tensor\n",
    "- mean – 正态分布的均值\n",
    "- std – 正态分布的标准差"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-1.7107,  2.0186,  1.3307, -0.6326,  0.1447],\n",
      "        [-0.5574,  0.1932,  0.1790,  0.6890,  0.3451],\n",
      "        [ 0.0886, -0.3431,  0.7525, -1.4268,  0.2959]])\n"
     ]
    }
   ],
   "source": [
    "w = torch.Tensor(3,5)\n",
    "nn.init.normal_(w)\n",
    "print(w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## torch.nn.init.constant\n",
    "torch.nn.init.constant(tensor, val)  \n",
    "用val的值填充输入的张量或变量  \n",
    "参数：  \n",
    "- tensor – n维的torch.Tensor或autograd.Variable\n",
    "- val – 用来填充张量的值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1.2000, 1.2000, 1.2000, 1.2000, 1.2000],\n",
      "        [1.2000, 1.2000, 1.2000, 1.2000, 1.2000],\n",
      "        [1.2000, 1.2000, 1.2000, 1.2000, 1.2000]])\n"
     ]
    }
   ],
   "source": [
    "w = torch.Tensor(3,5)\n",
    "nn.init.constant_(w,1.2)\n",
    "print(w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## torch.nn.init.eye\n",
    "torch.nn.init.eye(tensor)  \n",
    "用单位矩阵来填充2维输入张量或变量。在线性层尽可能多的保存输入特性。  \n",
    "参数：  \n",
    "- tensor – 2维的torch.Tensor或autograd.Variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 0., 0., 0., 0.],\n",
      "        [0., 1., 0., 0., 0.],\n",
      "        [0., 0., 1., 0., 0.]])\n"
     ]
    }
   ],
   "source": [
    "w = torch.Tensor(3,5)\n",
    "nn.init.eye_(w)\n",
    "print(w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## torch.nn.init.dirac\n",
    "torch.nn.init.dirac(tensor)  \n",
    "用dirac 函数来填充{3, 4, 5}维输入张量或变量。在卷积层尽可能多的保存输入通道特性。  \n",
    "参数：  \n",
    "- tensor – {3, 4, 5}维的torch.Tensor或autograd.Variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 1., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]],\n",
      "\n",
      "         [[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]],\n",
      "\n",
      "         [[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]],\n",
      "\n",
      "         [[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]],\n",
      "\n",
      "         [[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]]],\n",
      "\n",
      "\n",
      "        [[[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]],\n",
      "\n",
      "         [[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 1., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]],\n",
      "\n",
      "         [[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]],\n",
      "\n",
      "         [[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]],\n",
      "\n",
      "         [[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]]],\n",
      "\n",
      "\n",
      "        [[[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]],\n",
      "\n",
      "         [[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]],\n",
      "\n",
      "         [[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 1., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]],\n",
      "\n",
      "         [[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]],\n",
      "\n",
      "         [[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]]]])\n"
     ]
    }
   ],
   "source": [
    "w = torch.Tensor(3,16,5,5)\n",
    "nn.init.dirac_(w)\n",
    "print(w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## torch.nn.init.xavier_uniform\n",
    "torch.nn.init.xavier_uniform(tensor, gain=1)  \n",
    "根据Glorot, X.和Bengio, Y.在“Understanding the dif×culty of training deep feedforward neural networks”中描述的方法，用一个均匀分布生成值，填充输入的张量或变量。结果张量中的值采样自$U(-a, a)$，其中$a= gain * \\sqrt{ 2/(fan_{in} + fan_{out})}* \\sqrt{3}$。该方法也被称为Glorot initialisat。  \n",
    "参数：  \n",
    "- tensor – n维的torch.Tensor\n",
    "- gain - 可选的缩放因子"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.3894, -0.5256,  0.2854,  0.3381, -0.3993],\n",
      "        [ 0.1500, -0.1928,  0.7586, -0.8315, -0.1885],\n",
      "        [-0.0313, -0.2178,  0.3334,  0.0576, -0.4263]])\n"
     ]
    }
   ],
   "source": [
    "w = torch.Tensor(3,5)\n",
    "nn.init.xavier_uniform_(w,gain=1)\n",
    "print(w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可用一个正态分布生成值，填充输入的张量或变量。结果张量中的值采样自均值为0，标准差为$gain * \\sqrt{2/(fan_{in} + fan_{out})}$的正态分布。也被称为Glorot initialisation。  \n",
    "参数：  \n",
    "- tensor – n维的torch.Tensor\n",
    "- gain - 可选的缩放因子"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.5343, -0.4096,  0.3598,  0.2439],\n",
      "        [-1.3875,  0.0977, -0.1633, -0.1720],\n",
      "        [-0.4465, -0.0574,  0.3952,  0.0166]])\n"
     ]
    }
   ],
   "source": [
    "b=torch.Tensor(3,4)\n",
    "nn.init.xavier_normal_(b, gain=1)\n",
    "print(b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## torch.nn.init.kaiming_uniform\n",
    "torch.nn.init.kaiming_uniform(tensor, a=0, mode='fan_in')\n",
    "根据He, K等人于2015年在“Delving deep into recti×ers: Surpassing human-level performance on ImageNet classi×cation”中描述的方法，用一个均匀分布生成值，填充输入的张量或变量。结果张量中的值采样自$U(-bound, bound)$，其中$bound = \\sqrt{2/((1 + a^2) * fan_{in})} * \\sqrt{3}$。也被称为He initialisation.\n",
    "参数：  \n",
    "- tensor – n维的torch.Tensor或autograd.Variable\n",
    "- a -这层之后使用的recti×er的斜率系数（ReLU的默认值为0）\n",
    "- mode -可以为“fan_in”（默认）或“fan_out”。“fan_in”保留前向传播时权值方差的量级，“fan_out”保留反向传播时的量级。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.2284, -0.1150, -0.0251, -0.2666,  0.1354],\n",
      "        [ 0.5683, -0.4818, -0.0692, -0.1618, -0.9306],\n",
      "        [-0.8497, -1.2836,  0.5704, -0.5866, -0.1795]])\n"
     ]
    }
   ],
   "source": [
    "w = torch.Tensor(3,5)\n",
    "nn.init.kaiming_normal_(w,a=0,mode='fan_in')\n",
    "print(w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## torch.nn.init.kaiming_normal\n",
    "torch.nn.init.kaiming_normal(tensor, a=0, mode='fan_in')\n",
    "根据He, K等人在“Delving deep into recti×ers: Surpassing human-level performance on ImageNet classi×cation”中描述的方法，用一个正态分布生成值，填充输入的张量或变量。结果张量中的值采样自均值为0，标准差为$\\sqrt{2/((1 + a^2) * fan_{in})}$的正态分布。  \n",
    "参数：  \n",
    "- tensor – n维的torch.Tensor或 autograd.Variable\n",
    "- a - 这层之后使用的recti×er的斜率系数（ReLU的默认值为0）\n",
    "- mode - 可以为“fan_in”（默认）或“fan_out”。“fan_in”保留前向传播时权值方差的量级，“fan_out”保留反向传播时的量级。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.4962,  0.2692,  0.1398,  0.0752,  0.9633],\n",
      "        [-1.1670, -0.0766,  0.4635,  0.0093,  0.0740],\n",
      "        [-0.8205,  1.0357,  0.2439, -0.8035,  0.5299]])\n"
     ]
    }
   ],
   "source": [
    "w = torch.Tensor(3, 5)\n",
    "nn.init.kaiming_normal_(w, mode='fan_out')\n",
    "print(w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## torch.nn.init.orthogonal\n",
    "torch.nn.init.orthogonal(tensor, gain=1)  \n",
    "用（半）正交矩阵填充输入的张量或变量。输入张量必须至少是2维的，对于更高维度的张量，超出的维度会被展平，视作行等于第一个维度，列等于稀疏矩阵乘积的2维表示。其中非零元素生成自均值为0，标准差为std的正态分布。  \n",
    "参数：  \n",
    "- tensor – n维的torch.Tensor或 autograd.Variable，其中n>=2\n",
    "- gain -可选"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.7751,  0.2943, -0.4557,  0.3145, -0.0775],\n",
      "        [ 0.0891, -0.7561, -0.1506,  0.3916,  0.4944],\n",
      "        [-0.0637,  0.5475,  0.2297,  0.1705,  0.7838]])\n"
     ]
    }
   ],
   "source": [
    "w = torch.Tensor(3, 5)\n",
    "nn.init.orthogonal_(w)\n",
    "print(w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## torch.nn.init.sparse\n",
    "torch.nn.init.sparse(tensor, sparsity, std=0.01)  \n",
    "将2维的输入张量或变量当做稀疏矩阵填充，其中非零元素根据一个均值为0，标准差为std的正态分布生成。 参考Martens, J.(2010)的 “Deep learning via Hessian-free optimization”.  \n",
    "参数：  \n",
    "- tensor – n维的torch.Tensor或autograd.Variable\n",
    "- sparsity - 每列中需要被设置成零的元素比例\n",
    "- std - 用于生成非零值的正态分布的标准差"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.0067, -0.0205,  0.0000,  0.0166,  0.0155],\n",
      "        [-0.0116,  0.0048,  0.0067,  0.0004, -0.0143],\n",
      "        [ 0.0000,  0.0000, -0.0057,  0.0000,  0.0000]])\n"
     ]
    }
   ],
   "source": [
    "w = torch.Tensor(3, 5)\n",
    "nn.init.sparse_(w, sparsity=0.1)\n",
    "print(w)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
